{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Auxiliary Classifier Generative Adversarial Networks (AC-GAN) on MNIST\n",
    "\n",
    "Modified from Keras examples:\n",
    "\n",
    "https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py\n",
    "\n",
    "Original repo: https://github.com/lukedeo/keras-acgan\n",
    "\n",
    "Here, we use the tensorflow backend. The learning rate is decreased."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "from keras.datasets import mnist\n",
    "from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout\n",
    "from keras.layers import Multiply, LeakyReLU, UpSampling2D, Conv2D\n",
    "from keras.models import Sequential, Model\n",
    "from keras.optimizers import Adam\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(1337)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def build_generator(latent_size):\n",
    "    # we will map a pair of (z, L), where z is a latent vector and L is a\n",
    "    # label drawn from P_c, to image space (..., 28, 28, 1)\n",
    "    cnn = Sequential()\n",
    "\n",
    "    cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))\n",
    "    cnn.add(Dense(128 * 7 * 7, activation='relu'))\n",
    "    cnn.add(Reshape((7, 7, 128)))\n",
    "\n",
    "    # upsample to (..., 14, 14)\n",
    "    cnn.add(UpSampling2D(size=(2, 2)))\n",
    "    cnn.add(Conv2D(256, 5, padding='same', activation='relu',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # upsample to (..., 28, 28)\n",
    "    cnn.add(UpSampling2D(size=(2, 2)))\n",
    "    cnn.add(Conv2D(128, 5, padding='same', activation='relu',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # take a channel axis reduction\n",
    "    cnn.add(Conv2D(1, 2, padding='same', activation='tanh',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # this is the z space commonly refered to in GAN papers\n",
    "    latent = Input(shape=(latent_size,))\n",
    "\n",
    "    # this will be our label\n",
    "    image_class = Input(shape=(1,), dtype='int32')\n",
    "\n",
    "    # 10 classes in MNIST\n",
    "    emb = Embedding(10, latent_size, embeddings_initializer='glorot_normal')(image_class)\n",
    "    cls = Flatten()(emb)\n",
    "\n",
    "    # hadamard product between z-space and a class conditional embedding\n",
    "    h = Multiply()([latent, cls])\n",
    "\n",
    "    fake_image = cnn(h)\n",
    "\n",
    "    return Model([latent, image_class], fake_image)\n",
    "\n",
    "\n",
    "def build_discriminator():\n",
    "    # build a relatively standard conv net, with LeakyReLUs as suggested in\n",
    "    # the reference paper\n",
    "    cnn = Sequential()\n",
    "\n",
    "    cnn.add(Conv2D(32, 3, padding='same', strides=2, input_shape=(28, 28, 1)))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(64, 3, padding='same', strides=1))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(128, 3, padding='same', strides=2))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(256, 3, padding='same', strides=1))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Flatten())\n",
    "\n",
    "    image = Input(shape=(28, 28, 1))\n",
    "\n",
    "    features = cnn(image)\n",
    "\n",
    "    # first output (name=generation) is whether or not the discriminator\n",
    "    # thinks the image that is being shown is fake, and the second output\n",
    "    # (name=auxiliary) is the class that the discriminator thinks the image\n",
    "    # belongs to.\n",
    "    fake = Dense(1, activation='sigmoid', name='generation')(features)\n",
    "    aux = Dense(10, activation='softmax', name='auxiliary')(features)\n",
    "\n",
    "    return Model(image, [fake, aux])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# batch and latent size taken from the paper\n",
    "epochs = 50\n",
    "batch_size = 100\n",
    "latent_size = 100\n",
    "\n",
    "# Adam parameters suggested in https://arxiv.org/abs/1511.06434\n",
    "# decreased learning rate from repo settings\n",
    "adam_lr = 0.00005\n",
    "adam_beta_1 = 0.5\n",
    "\n",
    "# build the discriminator\n",
    "discriminator = build_discriminator()\n",
    "discriminator.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']\n",
    ")\n",
    "\n",
    "# build the generator\n",
    "generator = build_generator(latent_size)\n",
    "generator.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss='binary_crossentropy'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "____________________________________________________________________________________________________\n",
      "Layer (type)                     Output Shape          Param #     Connected to                     \n",
      "====================================================================================================\n",
      "input_3 (InputLayer)             (None, 1)             0                                            \n",
      "____________________________________________________________________________________________________\n",
      "embedding_1 (Embedding)          (None, 1, 100)        1000        input_3[0][0]                    \n",
      "____________________________________________________________________________________________________\n",
      "input_2 (InputLayer)             (None, 100)           0                                            \n",
      "____________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)              (None, 100)           0           embedding_1[0][0]                \n",
      "____________________________________________________________________________________________________\n",
      "multiply_1 (Multiply)            (None, 100)           0           input_2[0][0]                    \n",
      "                                                                   flatten_2[0][0]                  \n",
      "____________________________________________________________________________________________________\n",
      "sequential_2 (Sequential)        (None, 28, 28, 1)     8171521     multiply_1[0][0]                 \n",
      "====================================================================================================\n",
      "Total params: 8,172,521\n",
      "Trainable params: 8,172,521\n",
      "Non-trainable params: 0\n",
      "____________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "generator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "____________________________________________________________________________________________________\n",
      "Layer (type)                     Output Shape          Param #     Connected to                     \n",
      "====================================================================================================\n",
      "input_1 (InputLayer)             (None, 28, 28, 1)     0                                            \n",
      "____________________________________________________________________________________________________\n",
      "sequential_1 (Sequential)        (None, 12544)         387840      input_1[0][0]                    \n",
      "____________________________________________________________________________________________________\n",
      "generation (Dense)               (None, 1)             12545       sequential_1[1][0]               \n",
      "____________________________________________________________________________________________________\n",
      "auxiliary (Dense)                (None, 10)            125450      sequential_1[1][0]               \n",
      "====================================================================================================\n",
      "Total params: 525,835\n",
      "Trainable params: 525,835\n",
      "Non-trainable params: 0\n",
      "____________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "discriminator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "latent = Input(shape=(latent_size, ))\n",
    "image_class = Input(shape=(1,), dtype='int32')\n",
    "\n",
    "# get a fake image\n",
    "fake = generator([latent, image_class])\n",
    "\n",
    "# we only want to be able to train generation for the combined model\n",
    "discriminator.trainable = False\n",
    "fake, aux = discriminator(fake)\n",
    "combined = Model([latent, image_class], [fake, aux])\n",
    "\n",
    "combined.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']\n",
    ")\n",
    "\n",
    "# get our mnist data, and force it to be of shape (..., 28, 28, 1) with\n",
    "# range [-1, 1]\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "X_train = (X_train.astype(np.float32) - 127.5) / 127.5\n",
    "X_train = np.expand_dims(X_train, axis=3)\n",
    "X_test = (X_test.astype(np.float32) - 127.5) / 127.5\n",
    "X_test = np.expand_dims(X_test, axis=3)\n",
    "\n",
    "num_train, num_test = X_train.shape[0], X_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch\tL_s(G)\tL_s(G)\tL_s(D)\tL_s(D)\tL_c(G)\tL_c(G)\tL_c(D)\tL_c(D)\n",
      "1\t1.5043\t1.3493\t0.4646\t0.6713\t2.3628\t2.2071\t1.7203\t1.4392\n",
      "2\t1.3244\t1.1152\t0.5762\t0.5989\t1.1588\t0.2964\t0.8854\t0.3811\n",
      "3\t1.2463\t1.2954\t0.5824\t0.4954\t0.2412\t0.1020\t0.3706\t0.2643\n",
      "4\t1.1428\t1.3709\t0.6024\t0.4410\t0.1381\t0.0566\t0.2731\t0.1755\n",
      "5\t1.4719\t0.8972\t0.4950\t0.5593\t0.1491\t0.0481\t0.2647\t0.2130\n",
      "6\t1.2936\t1.3587\t0.5415\t0.4765\t0.0782\t0.0156\t0.1912\t0.1147\n",
      "7\t1.1821\t0.7369\t0.5868\t0.6817\t0.0459\t0.0315\t0.1445\t0.1050\n",
      "8\t1.0301\t0.9090\t0.6196\t0.5273\t0.0361\t0.0220\t0.1232\t0.0847\n",
      "9\t0.9657\t0.8546\t0.6343\t0.5678\t0.0208\t0.0023\t0.1014\t0.0647\n",
      "10\t0.9368\t0.7325\t0.6338\t0.6280\t0.0208\t0.0038\t0.0936\t0.0608\n",
      "11\t0.9013\t0.7860\t0.6410\t0.6273\t0.0181\t0.0048\t0.0844\t0.0550\n",
      "12\t0.9035\t0.6963\t0.6413\t0.6403\t0.0141\t0.0035\t0.0766\t0.0516\n",
      "13\t0.8247\t0.7233\t0.6735\t0.6877\t0.0132\t0.0071\t0.0721\t0.0478\n",
      "14\t0.8030\t0.8499\t0.6727\t0.6389\t0.0127\t0.0113\t0.0687\t0.0484\n",
      "15\t0.8020\t0.6725\t0.6819\t0.6773\t0.0096\t0.0027\t0.0619\t0.0394\n",
      "16\t0.7598\t0.6672\t0.6929\t0.7116\t0.0081\t0.0017\t0.0593\t0.0355\n",
      "17\t0.7324\t0.6939\t0.6998\t0.6806\t0.0080\t0.0131\t0.0551\t0.0392\n",
      "18\t0.7285\t0.6936\t0.6986\t0.6958\t0.0057\t0.0076\t0.0510\t0.0346\n",
      "19\t0.7172\t0.6839\t0.7034\t0.6970\t0.0051\t0.0026\t0.0489\t0.0305\n",
      "20\t0.7179\t0.7293\t0.7003\t0.6847\t0.0056\t0.0034\t0.0464\t0.0292\n",
      "21\t0.7154\t0.7009\t0.7014\t0.7013\t0.0049\t0.0007\t0.0449\t0.0269\n",
      "22\t0.7151\t0.7358\t0.7017\t0.6810\t0.0047\t0.0016\t0.0430\t0.0258\n",
      "23\t0.7147\t0.7004\t0.7013\t0.7053\t0.0045\t0.0021\t0.0413\t0.0251\n",
      "24\t0.7127\t0.6742\t0.7021\t0.7056\t0.0039\t0.0010\t0.0396\t0.0242\n",
      "25\t0.7095\t0.6841\t0.7018\t0.7270\t0.0035\t0.0036\t0.0379\t0.0242\n",
      "26\t0.7078\t0.6882\t0.7023\t0.7166\t0.0036\t0.0015\t0.0373\t0.0223\n",
      "27\t0.7087\t0.6954\t0.7016\t0.7232\t0.0033\t0.0013\t0.0356\t0.0213\n",
      "28\t0.7082\t0.6766\t0.7019\t0.7195\t0.0033\t0.0007\t0.0342\t0.0210\n",
      "29\t0.7101\t0.6731\t0.7004\t0.7143\t0.0032\t0.0003\t0.0335\t0.0202\n",
      "30\t0.7066\t0.6885\t0.7010\t0.7175\t0.0029\t0.0006\t0.0320\t0.0196\n",
      "31\t0.7076\t0.6805\t0.7001\t0.7174\t0.0026\t0.0003\t0.0310\t0.0193\n",
      "32\t0.7064\t0.7104\t0.7002\t0.7046\t0.0026\t0.0008\t0.0312\t0.0189\n",
      "33\t0.7063\t0.7012\t0.6996\t0.7042\t0.0026\t0.0003\t0.0298\t0.0187\n",
      "34\t0.7061\t0.7042\t0.6994\t0.7128\t0.0024\t0.0006\t0.0289\t0.0184\n",
      "35\t0.7058\t0.6985\t0.6992\t0.7134\t0.0022\t0.0003\t0.0279\t0.0174\n",
      "36\t0.7057\t0.6972\t0.6991\t0.7206\t0.0023\t0.0005\t0.0276\t0.0170\n",
      "37\t0.7079\t0.7099\t0.6986\t0.6960\t0.0024\t0.0001\t0.0272\t0.0169\n",
      "38\t0.7058\t0.7233\t0.6980\t0.6988\t0.0021\t0.0001\t0.0262\t0.0165\n",
      "39\t0.7060\t0.6962\t0.6982\t0.7057\t0.0021\t0.0002\t0.0260\t0.0163\n",
      "40\t0.7049\t0.6855\t0.6979\t0.7017\t0.0020\t0.0002\t0.0250\t0.0159\n",
      "41\t0.7043\t0.7171\t0.6985\t0.7046\t0.0020\t0.0002\t0.0250\t0.0155\n",
      "42\t0.7045\t0.6962\t0.6977\t0.7028\t0.0019\t0.0002\t0.0248\t0.0156\n",
      "43\t0.7044\t0.6829\t0.6979\t0.7125\t0.0018\t0.0001\t0.0235\t0.0161\n",
      "44\t0.7037\t0.6832\t0.6979\t0.7089\t0.0017\t0.0002\t0.0234\t0.0153\n",
      "45\t0.7044\t0.6833\t0.6970\t0.7133\t0.0018\t0.0002\t0.0239\t0.0148\n",
      "46\t0.7043\t0.6956\t0.6973\t0.7078\t0.0017\t0.0001\t0.0227\t0.0147\n",
      "47\t0.7044\t0.6855\t0.6971\t0.7026\t0.0016\t0.0003\t0.0225\t0.0143\n",
      "48\t0.7031\t0.6789\t0.6971\t0.7147\t0.0017\t0.0001\t0.0227\t0.0146\n",
      "49\t0.7041\t0.6814\t0.6972\t0.6942\t0.0017\t0.0002\t0.0224\t0.0144\n",
      "50\t0.7039\t0.7009\t0.6966\t0.7032\t0.0015\t0.0003\t0.0214\t0.0140\n",
      "done.\n"
     ]
    }
   ],
   "source": [
    "print('Epoch\\tL_s(G)\\tL_s(G)\\tL_s(D)\\tL_s(D)\\tL_c(G)\\tL_c(G)\\tL_c(D)\\tL_c(D)')\n",
    "for epoch in range(epochs):\n",
    "    print(epoch + 1, end='\\t', flush=True)\n",
    "\n",
    "    num_batches = int(X_train.shape[0] / batch_size)\n",
    "\n",
    "    epoch_gen_loss = []\n",
    "    epoch_disc_loss = []\n",
    "\n",
    "    for index in range(num_batches):\n",
    "        # generate a new batch of noise\n",
    "        noise = np.random.uniform(-1, 1, (batch_size, latent_size))\n",
    "\n",
    "        # get a batch of real images\n",
    "        image_batch = X_train[index * batch_size:(index + 1) * batch_size]\n",
    "        label_batch = y_train[index * batch_size:(index + 1) * batch_size]\n",
    "\n",
    "        # sample some labels from p_c\n",
    "        sampled_labels = np.random.randint(0, 10, batch_size)\n",
    "\n",
    "        # generate a batch of fake images, using the generated labels as a\n",
    "        # conditioner. We reshape the sampled labels to be\n",
    "        # (batch_size, 1) so that we can feed them into the embedding\n",
    "        # layer as a length one sequence\n",
    "        generated_images = generator.predict(\n",
    "            [noise, sampled_labels.reshape((-1, 1))], verbose=0)\n",
    "\n",
    "        X = np.concatenate((image_batch, generated_images))\n",
    "        y = np.array([1] * batch_size + [0] * batch_size)\n",
    "        aux_y = np.concatenate((label_batch, sampled_labels), axis=0)\n",
    "\n",
    "        # see if the discriminator can figure itself out...\n",
    "        epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))\n",
    "\n",
    "        # make new noise. we generate 2 * batch size here such that we have\n",
    "        # the generator optimize over an identical number of images as the\n",
    "        # discriminator\n",
    "        noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))\n",
    "        sampled_labels = np.random.randint(0, 10, 2 * batch_size)\n",
    "\n",
    "        # we want to train the generator to trick the discriminator\n",
    "        # For the generator, we want all the {fake, not-fake} labels to say\n",
    "        # not-fake\n",
    "        trick = np.ones(2 * batch_size)\n",
    "\n",
    "        epoch_gen_loss.append(\n",
    "            combined.train_on_batch(\n",
    "                [noise, sampled_labels.reshape((-1, 1))],\n",
    "                [trick, sampled_labels]\n",
    "            )\n",
    "        )\n",
    "\n",
    "    # evaluate the testing loss here\n",
    "\n",
    "    # generate a new batch of noise\n",
    "    noise = np.random.uniform(-1, 1, (num_test, latent_size))\n",
    "\n",
    "    # sample some labels from p_c and generate images from them\n",
    "    sampled_labels = np.random.randint(0, 10, num_test)\n",
    "    generated_images = generator.predict(\n",
    "        [noise, sampled_labels.reshape((-1, 1))], verbose=0)\n",
    "\n",
    "    X = np.concatenate((X_test, generated_images))\n",
    "    y = np.array([1] * num_test + [0] * num_test)\n",
    "    aux_y = np.concatenate((y_test, sampled_labels), axis=0)\n",
    "\n",
    "    # see if the discriminator can figure itself out...\n",
    "    discriminator_test_loss = discriminator.evaluate(X, [y, aux_y], verbose=0)\n",
    "\n",
    "    discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)\n",
    "\n",
    "    # make new noise\n",
    "    noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))\n",
    "    sampled_labels = np.random.randint(0, 10, 2 * num_test)\n",
    "\n",
    "    trick = np.ones(2 * num_test)\n",
    "\n",
    "    generator_test_loss = combined.evaluate(\n",
    "        [noise, sampled_labels.reshape((-1, 1))],\n",
    "        [trick, sampled_labels],\n",
    "        verbose=0\n",
    "    )\n",
    "\n",
    "    generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)\n",
    "    \n",
    "    print('{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}'.format(\n",
    "        # generation loss\n",
    "        generator_train_loss[1], generator_test_loss[1],\n",
    "        discriminator_train_loss[1], discriminator_test_loss[1],\n",
    "        # auxillary loss\n",
    "        generator_train_loss[2], generator_test_loss[2],\n",
    "        discriminator_train_loss[2], discriminator_test_loss[2],\n",
    "    ))\n",
    "\n",
    "    # save weights every epoch\n",
    "    generator.save_weights('mnist_acgan/mnist_acgan_g_weights.hdf5', True)\n",
    "    discriminator.save_weights('mnist_acgan/mnist_acgan_d_weights.hdf5', True)\n",
    "\n",
    "    # generate some digits to display\n",
    "    noise = np.random.uniform(-1, 1, (100, latent_size))\n",
    "\n",
    "    sampled_labels = np.array([[i] * 10 for i in range(10)]).reshape(-1, 1)\n",
    "\n",
    "    # get a batch to display\n",
    "    generated_images = generator.predict([noise, sampled_labels], verbose=0)\n",
    "\n",
    "    # arrange them into a grid\n",
    "    img = (np.concatenate([r.reshape(-1, 28)\n",
    "                           for r in np.split(generated_images, 10)\n",
    "                           ], axis=-1) * 127.5 + 127.5).astype(np.uint8)\n",
    "\n",
    "    Image.fromarray(img).save('mnist_acgan/mnist_acgan_generated_{0:03d}.png'.format(epoch))\n",
    "    \n",
    "print('done.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### save model architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "with open('mnist_acgan.json', 'w') as f:\n",
    "    f.write(generator.to_json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "!cp mnist_acgan/mnist_acgan_g_weights.hdf5 ./mnist_acgan.hdf5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### view generated history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "epoch 50:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "source": [
    "![title](mnist_acgan/mnist_acgan_generated_049.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def make_digit(digit=None):\n",
    "    noise = np.random.uniform(-1, 1, (1, latent_size))\n",
    "\n",
    "    sampled_label = np.array([\n",
    "            digit if digit is not None else np.random.randint(0, 10, 1)\n",
    "        ]).reshape(-1, 1)\n",
    "\n",
    "    generated_image = generator.predict(\n",
    "        [noise, sampled_label], verbose=0)\n",
    "\n",
    "    return np.squeeze((generated_image * 127.5 + 127.5).astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.5, 27.5, 27.5, -0.5)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAB7lJREFUeJzt3c+rj/kfxvH74GCOyZFpOiQxReokWdDYzEpZyI8SzWIa\nYhbKQulM2dqQf8CSkoWFhY0F/4AodM5wyAJhI9RQfg/nfDff7f26Nec4w7kej+019znHjOfci/e5\n70/P+Ph4A+SZ8V//AMB/Q/wQSvwQSvwQSvwQSvwQSvwQSvwQSvwQatYUfz+/TghfXs/n/EPu/BBK\n/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBK/BBqqp/nh8/26dOncp8xo753\n9fR81mPtsdz5IZT4IZT4IZT4IZT4IZT4IZSjPr6ojx8/tm43btwor92zZ0+5z5pV//W9evVq69bX\n11dem8CdH0KJH0KJH0KJH0KJH0KJH0KJH0L1jI9P6adm+4jub0zX348XL16U+9DQUOt2+vTp8tqx\nsbFy73pkt/o9gZMnT5bXdj0u/JXzEd1AO/FDKPFDKPFDKPFDKPFDKPFDKM/zh+s6x3/w4EG5Hzly\npNzPnj3bunWd43edtff395d79TsIz549K68dGBgo9+nAnR9CiR9CiR9CiR9CiR9CiR9CiR9COecP\nV71Xv2ma5vbt2+V+4cKFcq8+ZnvOnDnltfPnzy/3rnfv9/b2tm5Pnjwpr3XOD0xb4odQ4odQ4odQ\n4odQ4odQXt0d7uHDh+W+Y8eOch8eHi732bNnt25Lly4tr920aVO5//HHH+W+evXq1q06BpwGvLob\naCd+CCV+CCV+CCV+CCV+CCV+COWR3mnuzp075b5v375y/+uvv8q9OsdvmqbZsmVL67Zr167y2s2b\nN5f7999/X+7U3PkhlPghlPghlPghlPghlPghlPghlHP+aeDmzZut2+7du8trb926Ve7Vq7ebpml+\n+eWXcv/zzz9bt+p5+6Zpmnnz5pU7E+POD6HED6HED6HED6HED6HED6HED6Gc838DXrx4Ue6//fZb\n6zY6OlpeOzY2Vu5dZ+2Dg4PlvmDBgtat610AfFnu/BBK/BBK/BBK/BBK/BBK/BBK/BDKOf9XoOus\n/ciRI+VePZM/Pj5eXjtjRv3//x9//LHcFy1aVO6PHj1q3QYGBsprq98RYOLc+SGU+CGU+CGU+CGU\n+CGU+CFUT9dR0CSb0m/2rXjy5Em5r1mzptyfPXvWunUd5Z05c6bcf/rpp3KfNas+Lf7uu+9at4UL\nF5bXLl68uNx7enrKPdhn/Ytx54dQ4odQ4odQ4odQ4odQ4odQ4odQHun9Cly7dq3c//7773/9tTdu\n3Fjuu3btKvfqdwiapmkuXbpU7nfv3m3dHj58WF57/Pjxcl+2bFm5U3Pnh1Dih1Dih1Dih1Dih1Di\nh1Dih1DO+afAq1evyv3AgQPl/vHjx3KvnmtfsWJFee3w8HC5Hzt2rNwvX75c7m/fvm3d3r9/X17b\n9bNduXKl3Pv7+8s9nTs/hBI/hBI/hBI/hBI/hBI/hBI/hPLe/inQdV69YcOGcu86D6/ejb9q1ary\n2gcPHpT7y5cvy71L13v9K12fOTA0NFTuR48ebd2m+Tv/vbcfaCd+CCV+CCV+CCV+CCV+CCV+COWc\nfwqcOnWq3Pfv31/uXc/z9/b2tm5d/327vnaXBQsWlPv58+dbt5GRkfLaw4cPl/vy5cvL/fr1663b\nvHnzymu/cc75gXbih1Dih1Dih1Dih1Dih1Be3T0Juo7TLl68WO4TPW6rvn9fX195bdcjv+fOnSv3\npUuXlnv1WO7g4GB5bddR39OnT8u9+mjzaX7U91nc+SGU+CGU+CGU+CGU+CGU+CGU+CGUc/5J0HVO\n3/Xq7i5dr5n+/fffW7cTJ06U186ZM2dC33tsbKzc792717pt3bq1vLbrleXz588v94m8NjyBOz+E\nEj+EEj+EEj+EEj+EEj+EEj+EchA6CT59+lTuE31ev+u8+tdff23d5s6dO6Hv3fVnGx0dLfft27e3\nbo8fPy6v7fqI7h9++KHcJ/pnn+7c+SGU+CGU+CGU+CGU+CGU+CGU+CGUc/5J8ObNmy/69f/5559y\n37lzZ+vW9cx812cOXLhwodxfv379r79+1+8vrF27ttwPHTpU7l0fH57OnR9CiR9CiR9CiR9CiR9C\niR9C9XQd9UyyKf1mU6Xrkd1169aV+8jIyGT+OF+V6rHavXv3ltcePHiw3FeuXFnuM2fOLPdprH7f\n+v+580Mo8UMo8UMo8UMo8UMo8UMo8UMo5/xT4OnTp+X+888/l/vz58/L/d27d61b10dodz1W29fX\nV+7Vx4M3TdNs27atdVu/fn15bX9/f7nTyjk/0E78EEr8EEr8EEr8EEr8EEr8EMo5/zfgw4cP5X7/\n/v3W7eXLl+W1XWfpvb295b5kyZJ/fX3w8/ZfmnN+oJ34IZT4IZT4IZT4IZT4IZT4IZRzfph+nPMD\n7cQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPocQP\nocQPocQPocQPocQPocQPocQPocQPocQPocQPocQPoWZN8ff7rI8OBr48d34IJX4IJX4IJX4IJX4I\nJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4IJX4I9T9p\nfmEUrkCMggAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f749ecc57f0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(make_digit(digit=6), cmap='gray_r', interpolation='nearest')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
